// MIT License
//
// # Copyright (c) 2022 Ignacio Vizzo, Cyrill Stachniss, University of Bonn
// * # RGB Part is from https://github.com/paucarre in this commit
// *
// https://github.com/PRBonn/vdbfusion/tree/94c0254ece6f4f9f13164cf6a3f8eb6fd324209d/src/vdbfusion/vdbfusion
// *
// * This repo Maintainer: Kin ZHANG (kin_eng@163.com)
// *
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#include "VDBVolume.h"

// OpenVDB
#include <openvdb/Types.h>
#include <openvdb/math/DDA.h>
#include <openvdb/math/Ray.h>
#include <openvdb/openvdb.h>

#include <Eigen/Core>
#include <algorithm>
#include <cmath>
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

namespace {

/// Signed distance between a voxel center and a measured point.
///
/// The magnitude is the Euclidean distance |point - voxel_center|. The sign is
/// the sign of the dot product (voxel_center - origin) . (point - voxel_center):
/// positive when the point lies further along the origin->voxel direction than
/// the voxel center, negative otherwise.
///
/// A dot product of exactly zero (voxel center coinciding with the point, or
/// the two vectors being perpendicular) is treated as non-negative. Dividing
/// by |proj| in that case would be 0/0 and produce NaN.
///
/// @param origin        sensor position
/// @param point         measured point
/// @param voxel_center  position of the voxel center
/// @return the signed distance, narrowed to float
float ComputeSDF(const Eigen::Vector3d &origin, const Eigen::Vector3d &point,
                 const Eigen::Vector3d &voxel_center) {
  const Eigen::Vector3d v_voxel_origin = voxel_center - origin;
  const Eigen::Vector3d v_point_voxel = point - voxel_center;
  const double dist = v_point_voxel.norm();
  const double proj = v_voxel_origin.dot(v_point_voxel);
  // Explicit sign selection instead of proj / |proj| so that proj == 0 is
  // well defined.
  const double sign = proj < 0.0 ? -1.0 : 1.0;
  return static_cast<float>(sign * dist);
}

/// World-space position used as the center of the index-space voxel `voxel`.
///
/// Computed as the world position of the voxel's index coordinate, shifted
/// along every axis by half of the first component of the transform's voxel
/// size.
Eigen::Vector3d GetVoxelCenter(const openvdb::Coord &voxel,
                               const openvdb::math::Transform &xform) {
  // The edge length is narrowed to float before it is halved.
  const float edge_length = xform.voxelSize()[0];
  const double half_edge = edge_length / 2.0;
  const openvdb::math::Vec3d center = xform.indexToWorld(voxel) + half_edge;
  return Eigen::Vector3d(center.x(), center.y(), center.z());
}

/// Weighted average of two colors, rounded per channel to the nearest integer.
///
/// @param color1   first color
/// @param weight1  weight applied to `color1`
/// @param color2   second color
/// @param weight2  weight applied to `color2`
/// @return (color1 * weight1 + color2 * weight2) / (weight1 + weight2) for
///         each channel. When the weight sum is zero or not finite the
///         normalization is undefined, so `color1` is returned unchanged
///         instead of converting NaN to int.
openvdb::Vec3i BlendColors(const openvdb::Vec3i &color1, float weight1,
                           const openvdb::Vec3i &color2, float weight2) {
  const float weight_sum = weight1 + weight2;
  if (weight_sum == 0.0f || !std::isfinite(weight_sum)) {
    return color1;
  }
  const float w1 = weight1 / weight_sum;
  const float w2 = weight2 / weight_sum;
  // Blend a single channel with the normalized weights.
  const auto blend = [w1, w2](int c1, int c2) {
    return static_cast<int>(std::round(c1 * w1 + c2 * w2));
  };
  return {blend(color1[0], color2[0]), blend(color1[1], color2[1]),
          blend(color1[2], color2[2])};
}

}  // namespace

namespace vdbfusion {

/// Creates an empty volume made of three grids (signed distance, weights and
/// colors), each with its own linear transform built from `voxel_size`.
///
/// @param voxel_size     voxel size handed to the linear transforms
/// @param sdf_trunc      truncation distance; also the background value of the
///                       signed distance grid
/// @param space_carving  stored for use by the ray-based Integrate overloads
VDBVolume::VDBVolume(float voxel_size, float sdf_trunc,
                     bool space_carving /* = false*/)
    : voxel_size_(voxel_size),
      sdf_trunc_(sdf_trunc),
      space_carving_(space_carving) {
  // Each grid receives a freshly created transform rather than a shared one.
  const auto make_transform = [this] {
    return openvdb::math::Transform::createLinearTransform(voxel_size_);
  };

  // D(x): signed distances, background = truncation distance.
  tsdf_ = openvdb::FloatGrid::create(sdf_trunc_);
  tsdf_->setName("D(x): signed distance grid");
  tsdf_->setTransform(make_transform());
  tsdf_->setGridClass(openvdb::GRID_LEVEL_SET);

  // W(x): accumulated weights, background = 0.
  weights_ = openvdb::FloatGrid::create(0.0f);
  weights_->setName("W(x): weights grid");
  weights_->setTransform(make_transform());
  weights_->setGridClass(openvdb::GRID_UNKNOWN);

  // C(x): colors, background = (0, 0, 0).
  colors_ = openvdb::Vec3IGrid::create(openvdb::Vec3I(0, 0, 0));
  colors_->setName("C(x): colors grid");
  colors_->setTransform(make_transform());
  colors_->setGridClass(openvdb::GRID_UNKNOWN);
}

/// Integrates a point cloud into the signed distance and weight grids.
///
/// For every point a ray is cast from `origin` towards the point and walked
/// voxel by voxel between depth - sdf_trunc_ (or 0 when space_carving_ is set)
/// and depth + sdf_trunc_. Each visited voxel whose signed distance is greater
/// than -sdf_trunc_ has its stored distance replaced by the weighted running
/// average and its weight increased.
///
/// Points that coincide with `origin` or contain non-finite coordinates define
/// no ray and are skipped. A voxel update whose combined weight is not
/// positive is skipped as well, since the average would divide by zero.
///
/// @param points              measured points
/// @param origin              sensor position the points were observed from
/// @param weighting_function  maps a signed distance to an integration weight
void VDBVolume::Integrate(
    const std::vector<Eigen::Vector3d> &points, const Eigen::Vector3d &origin,
    const std::function<float(float)> &weighting_function) {
  if (points.empty()) {
    std::cerr << "PointCloud provided is empty\n";
    return;
  }

  // Get some variables that are common to all rays
  const openvdb::math::Transform &xform = tsdf_->transform();
  const openvdb::Vec3R eye(origin.x(), origin.y(), origin.z());

  // Get the "unsafe" version of the grid accessors
  auto tsdf_acc = tsdf_->getUnsafeAccessor();
  auto weights_acc = weights_->getUnsafeAccessor();

  // Launch a for_each execution, use std::execution::par to parallelize this
  // region.
  // NOTE(review): the accessors above are shared by every iteration; confirm
  // thread-safety before switching to a parallel execution policy.
  std::for_each(points.cbegin(), points.cend(), [&](const auto &point) {
    // Get the direction from the sensor origin to the point
    const Eigen::Vector3d direction = point - origin;
    const auto depth = static_cast<float>(direction.norm());

    // A zero-length or non-finite direction cannot be normalized into a valid
    // ray, so there is nothing to traverse for this point.
    if (!std::isfinite(depth) || depth <= 0.0f) {
      return;
    }

    openvdb::Vec3R dir(direction.x(), direction.y(), direction.z());
    dir.normalize();

    // Truncate the Ray before and after the source unless space_carving_ is
    // specified.
    const float t0 = space_carving_ ? 0.0f : depth - sdf_trunc_;
    const float t1 = depth + sdf_trunc_;

    // Create one DDA per ray(per thread), the ray must operate on voxel grid
    // coordinates.
    const auto ray =
        openvdb::math::Ray<float>(eye, dir, t0, t1).worldToIndex(*tsdf_);
    openvdb::math::DDA<decltype(ray)> dda(ray);
    do {
      const auto voxel = dda.voxel();
      const auto voxel_center = GetVoxelCenter(voxel, xform);
      const auto sdf = ComputeSDF(origin, point, voxel_center);
      if (sdf > -sdf_trunc_) {
        const float tsdf = std::min(sdf_trunc_, sdf);
        const float weight = weighting_function(sdf);
        const float last_weight = weights_acc.getValue(voxel);
        const float last_tsdf = tsdf_acc.getValue(voxel);
        const float new_weight = weight + last_weight;
        // Only update when the running average is well defined; a combined
        // weight of zero would store NaN in the distance grid.
        if (new_weight > 0.0f) {
          const float new_tsdf =
              (last_tsdf * last_weight + tsdf * weight) / (new_weight);
          tsdf_acc.setValue(voxel, new_tsdf);
          weights_acc.setValue(voxel, new_weight);
        }
      }
    } while (dda.step());
  });
}

/// Fuses another float grid into this volume.
///
/// Walks every "value on" entry of `grid` and hands its value and coordinate
/// to UpdateTSDF together with `weighting_function`.
///
/// @param grid                grid whose active values are integrated
/// @param weighting_function  maps a signed distance to an integration weight
void VDBVolume::Integrate(
    openvdb::FloatGrid::Ptr grid,
    const std::function<float(float)> &weighting_function) {
  auto it = grid->cbeginValueOn();
  while (it.test()) {
    UpdateTSDF(it.getValue(), it.getCoord(), weighting_function);
    ++it;
  }
}
/// Folds a single signed distance observation into one voxel.
///
/// Observations with sdf <= -sdf_trunc_ are ignored. Otherwise the distance is
/// clamped from above to sdf_trunc_, weighted by `weighting_function(sdf)` and
/// merged into the stored value as a weighted running average; the stored
/// weight is increased by the new weight. If the combined weight is not
/// positive the voxel is left untouched, because the average would divide by
/// zero and store NaN.
///
/// @param sdf                 signed distance observed at `voxel`
/// @param voxel               index-space coordinate of the voxel to update
/// @param weighting_function  maps a signed distance to an integration weight
void VDBVolume::UpdateTSDF(
    const float &sdf, const openvdb::Coord &voxel,
    const std::function<float(float)> &weighting_function) {
  using AccessorRW = openvdb::tree::ValueAccessorRW<openvdb::FloatTree>;
  if (!(sdf > -sdf_trunc_)) {
    return;
  }

  AccessorRW tsdf_acc = AccessorRW(tsdf_->tree());
  AccessorRW weights_acc = AccessorRW(weights_->tree());
  const float tsdf = std::min(sdf_trunc_, sdf);
  const float weight = weighting_function(sdf);
  const float last_weight = weights_acc.getValue(voxel);
  const float last_tsdf = tsdf_acc.getValue(voxel);
  const float new_weight = weight + last_weight;
  // Nothing meaningful to store when there is no accumulated weight.
  if (!(new_weight > 0.0f)) {
    return;
  }

  const float new_tsdf =
      (last_tsdf * last_weight + tsdf * weight) / (new_weight);
  tsdf_acc.setValue(voxel, new_tsdf);
  weights_acc.setValue(voxel, new_weight);
}
/// Integrates a point cloud, optionally with per-point colors, into the
/// signed distance, weight and color grids.
///
/// For every point a ray is cast from `origin` towards the point and walked
/// voxel by voxel between depth - sdf_trunc_ (or 0 when space_carving_ is set)
/// and depth + sdf_trunc_. Each visited voxel whose signed distance is greater
/// than -sdf_trunc_ has its stored distance replaced by the weighted running
/// average and its weight increased; when colors are supplied, the voxel color
/// is blended with the point color using the previous and new weights.
///
/// Points that coincide with `origin` or contain non-finite coordinates define
/// no ray and are skipped. A voxel update whose combined weight is not
/// positive is skipped as well, since the averages would divide by zero.
///
/// @param points              measured points
/// @param colors              one color per point, or empty to integrate
///                            geometry only; a non-empty vector of a different
///                            size than `points` aborts the call
/// @param origin              sensor position the points were observed from
/// @param weighting_function  maps a signed distance to an integration weight
void VDBVolume::Integrate(
    const std::vector<Eigen::Vector3d> &points,
    const std::vector<openvdb::Vec3i> &colors, const Eigen::Vector3d &origin,
    const std::function<float(float)> &weighting_function) {
  if (points.empty()) {
    std::cerr << "PointCloud provided is empty\n";
    return;
  }
  const bool has_colors = !colors.empty();
  if (has_colors && points.size() != colors.size()) {
    std::cerr
        << "PointCloud and ColorCloud provided do not have the same size\n";
    return;
  }

  // Get some variables that are common to all rays
  const openvdb::math::Transform &xform = tsdf_->transform();
  const openvdb::Vec3R eye(origin.x(), origin.y(), origin.z());

  // Get the "unsafe" version of the grid accessors
  auto tsdf_acc = tsdf_->getUnsafeAccessor();
  auto weights_acc = weights_->getUnsafeAccessor();
  auto colors_acc = colors_->getUnsafeAccessor();

  // Iterate points
  for (size_t i = 0; i < points.size(); ++i) {
    // Get the direction from the sensor origin to the point
    const auto &point = points[i];
    const Eigen::Vector3d direction = point - origin;
    const auto depth = static_cast<float>(direction.norm());

    // A zero-length or non-finite direction cannot be normalized into a valid
    // ray, so there is nothing to traverse for this point.
    if (!std::isfinite(depth) || depth <= 0.0f) {
      continue;
    }

    openvdb::Vec3R dir(direction.x(), direction.y(), direction.z());
    dir.normalize();

    // Truncate the Ray before and after the source unless space_carving_ is
    // specified.
    const float t0 = space_carving_ ? 0.0f : depth - sdf_trunc_;
    const float t1 = depth + sdf_trunc_;

    // Create one DDA per ray(per thread), the ray must operate on voxel grid
    // coordinates.
    const auto ray =
        openvdb::math::Ray<float>(eye, dir, t0, t1).worldToIndex(*tsdf_);
    openvdb::math::DDA<decltype(ray)> dda(ray);
    do {
      const auto voxel = dda.voxel();
      const auto voxel_center = GetVoxelCenter(voxel, xform);
      const auto sdf = ComputeSDF(origin, point, voxel_center);
      if (sdf > -sdf_trunc_) {
        const float tsdf = std::min(sdf_trunc_, sdf);
        const float weight = weighting_function(sdf);
        const float last_weight = weights_acc.getValue(voxel);
        const float last_tsdf = tsdf_acc.getValue(voxel);
        const float new_weight = weight + last_weight;
        // Only update when the running averages are well defined; a combined
        // weight of zero would store NaN in the distance grid.
        if (new_weight > 0.0f) {
          const float new_tsdf =
              (last_tsdf * last_weight + tsdf * weight) / (new_weight);
          tsdf_acc.setValue(voxel, new_tsdf);
          weights_acc.setValue(voxel, new_weight);
          if (has_colors) {
            const auto color = colors_acc.getValue(voxel);
            openvdb::Vec3i new_color =
                BlendColors(color, last_weight, colors[i], weight);
            colors_acc.setValue(voxel, new_color);
          }
        }
      }
    } while (dda.step());
  }
}

/// Builds a cleaned copy of the signed distance grid.
///
/// A voxel is kept (value copied, marked active) when it is active in the
/// current signed distance tree and its weight is strictly greater than
/// `min_weight`. Every other voxel is set to the background value sdf_trunc_
/// and marked inactive. The volume itself is not modified.
///
/// @param min_weight  weight a voxel must exceed to be kept
/// @return a newly created level-set grid with the same voxel size
openvdb::FloatGrid::Ptr VDBVolume::Prune(float min_weight) const {
  // Bind by reference: a plain `auto` would copy both trees on every call,
  // while combine2Extended only needs read access to them.
  const auto &weights = weights_->tree();
  const auto &tsdf = tsdf_->tree();
  const auto background = sdf_trunc_;
  openvdb::FloatGrid::Ptr clean_tsdf = openvdb::FloatGrid::create(sdf_trunc_);
  clean_tsdf->setName("D(x): Pruned signed distance grid");
  clean_tsdf->setTransform(
      openvdb::math::Transform::createLinearTransform(voxel_size_));
  clean_tsdf->setGridClass(openvdb::GRID_LEVEL_SET);
  // a = signed distance value, b = weight value at the same coordinate.
  clean_tsdf->tree().combine2Extended(
      tsdf, weights,
      [min_weight, background](openvdb::CombineArgs<float> &args) {
        if (args.aIsActive() && args.b() > min_weight) {
          args.setResult(args.a());
          args.setResultIsActive(true);
        } else {
          args.setResult(background);
          args.setResultIsActive(false);
        }
      });
  return clean_tsdf;
}
}  // namespace vdbfusion